"""
Gulf of Mexico SST Missingness Audit
======================================
Computes per-year and aggregate missingness statistics across the
3-year Aqua MODIS SST benchmark dataset (2022-2024).

For each of the 956 daily CSV files produced by download_ocean_nasa.py,
this script computes:
  - Daytime SST missingness rate (% of grid cells cloud-masked)
  - Nighttime SST missingness rate (where available)
  - Daily spatial variance (sample, ddof=1): sigma^2(t) = sum( (SST - SST_bar)^2 ) / (n - 1)
  - Geographic bounds verification

These statistics directly support the three EDA findings in:
  "When the Gap Is the Signal: A Data-Centric Assumption Audit
   of Sea Surface Temperature Reconstruction Methods"

Output:
  Printed summary table by year + multi-year aggregate
  (redirect stdout to a file to save results)

Usage:
    python analyze_all_years.py
    python analyze_all_years.py > audit_results.txt

Data Expected:
    Daily CSV files named louisiana_YYYY-MM-DD.csv
    Located in: data/ocean_l3_MNAR_3YR/ (relative to this script)

Authors:
    Pujit Naga Sai Pavan Kumar Etha
    Louisiana Tech University
    Dataset: Gulf of Mexico SST MNAR Benchmark (2022-2024)
"""

import os
import pandas as pd
import numpy as np
from glob import glob

# ── Data Directory ────────────────────────────────────────────────────────────
# Resolved against the folder that contains this script (not the current
# working directory), so the dataset is found no matter where it is run from.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(_SCRIPT_DIR, "data", "ocean_l3_MNAR_3YR")


def get_daytime_sst_col(df):
    """
    Return the name of the daytime SST column in a daily CSV frame.

    Lookup order:
      1. The first column whose name contains the product tag
         'L3m.DAY.SST.sst.4km'.
      2. Otherwise, the first column whose whole name equals 'SST'
         when compared case-insensitively.

    Returns None when neither rule matches any column.
    """
    column_names = list(df.columns)

    tagged = next(
        (name for name in column_names if 'L3m.DAY.SST.sst.4km' in name),
        None,
    )
    if tagged is not None:
        return tagged

    # Fallback: a column named exactly 'SST' in any letter case
    # (a name that merely contains 'SST' does not qualify).
    return next((name for name in column_names if name.upper() == 'SST'), None)


def get_nighttime_sst_col(df):
    """
    Return the name of the nighttime SST column, or None if there is none.

    An NSST product column (name contains 'DAY.NSST.sst.4km') is preferred.
    Failing that, the first SST4 column (name contains 'DAY.SST4.sst4.4km')
    whose name does not contain 'NRT' is returned.
    """
    column_names = list(df.columns)

    nsst = next((name for name in column_names if 'DAY.NSST.sst.4km' in name), None)
    if nsst is not None:
        return nsst

    # SST4 fallback: columns tagged 'NRT' are skipped.
    return next(
        (
            name
            for name in column_names
            if 'DAY.SST4.sst4.4km' in name and 'NRT' not in name
        ),
        None,
    )


def compute_spatial_variance(df, sst_col):
    """
    Spatial variance of one day's SST over the observed (non-NaN) cells.

    The value is the sample variance (ddof=1):
        sigma^2_spatial(t) = sum( (SST_s,t - SST_bar_t)^2 ) / (n - 1)
    taken over the n cells that have a value on that day.

    Returns np.nan when fewer than two cells are observed, since the
    sample variance is undefined there. This is the metric used in
    Finding 3 (temporal heteroscedasticity).
    """
    observed = df[sst_col].dropna()
    enough_cells = len(observed) >= 2
    return float(np.var(observed, ddof=1)) if enough_cells else np.nan


def _year_from_filename(basename):
    """
    Extract the four-digit year from a 'louisiana_YYYY-MM-DD.csv' file name.

    Returns the year as a string, or "Unknown" when the characters in the
    year position are not four ASCII digits (e.g. a stray CSV in the folder).
    """
    date_str = basename.replace("louisiana_", "").replace(".csv", "")
    year = date_str[:4]
    if len(year) == 4 and all(ch in "0123456789" for ch in year):
        return year
    return "Unknown"


def _new_year_stats():
    """Return an empty accumulator for one year's statistics."""
    return {
        'file_count':   0,
        'day_total':    0, 'day_missing':   0, 'day_rates':   [],
        'night_total':  0, 'night_missing': 0, 'night_rates': [],
        'spatial_variances': []
    }


def analyze_all_years():
    """
    Run the missingness audit over every CSV in DATA_DIR and print a report.

    For each file (processed in sorted path order) the function accumulates,
    per year and overall:
      - daytime SST missing-cell counts and the per-day missingness rate,
      - nighttime SST missing-cell counts and rate, when such a column exists,
      - the daily spatial variance of the daytime SST column,
      - the running min/max of the 'Lat' and 'Lon' columns, when present.

    Missingness rates use the number of rows in the file as the denominator.
    Files whose name does not carry a four-digit year are grouped under
    "Unknown". Everything is written to stdout; nothing is returned. If
    DATA_DIR contains no CSV files, a hint is printed and the function
    returns early.
    """
    all_files = sorted(glob(os.path.join(DATA_DIR, "*.csv")))

    if not all_files:
        print(f"No CSV files found in: {DATA_DIR}")
        print("Run download_ocean_nasa.py first to generate the dataset.")
        return

    print(f"Found {len(all_files)} daily files in dataset.")
    print(f"Coverage: {os.path.basename(all_files[0])} to {os.path.basename(all_files[-1])}\n")

    stats_by_year   = {}
    overall_day_total   = 0
    overall_day_missing = 0
    overall_night_total   = 0
    overall_night_missing = 0
    spatial_variances   = []

    # Sentinels: stay infinite if no file ever contributes a usable coordinate.
    lat_min, lat_max = float('inf'),  float('-inf')
    lon_min, lon_max = float('inf'),  float('-inf')

    for i, filepath in enumerate(all_files):
        if i % 100 == 0 and i > 0:
            print(f"  Processed {i}/{len(all_files)} files...")

        # Year bucket comes from the file name: louisiana_YYYY-MM-DD.csv
        year = _year_from_filename(os.path.basename(filepath))
        if year not in stats_by_year:
            stats_by_year[year] = _new_year_stats()
        year_stats = stats_by_year[year]

        df = pd.read_csv(filepath, low_memory=False)
        year_stats['file_count'] += 1

        # Update geographic bounds
        if 'Lat' in df.columns and 'Lon' in df.columns:
            lat_min = min(lat_min, df['Lat'].min())
            lat_max = max(lat_max, df['Lat'].max())
            lon_min = min(lon_min, df['Lon'].min())
            lon_max = max(lon_max, df['Lon'].max())

        total_rows = len(df)

        # Daytime SST missingness
        day_col = get_daytime_sst_col(df)
        if day_col and total_rows > 0:
            n_missing = int(df[day_col].isnull().sum())
            miss_rate = n_missing / total_rows * 100
            year_stats['day_total']   += total_rows
            year_stats['day_missing'] += n_missing
            year_stats['day_rates'].append(miss_rate)
            overall_day_total   += total_rows
            overall_day_missing += n_missing

            # Daily spatial variance (Finding 3); NaN means too few observed cells
            sv = compute_spatial_variance(df, day_col)
            if not np.isnan(sv):
                year_stats['spatial_variances'].append(sv)
                spatial_variances.append(sv)

        # Nighttime SST missingness
        night_col = get_nighttime_sst_col(df)
        if night_col and total_rows > 0:
            n_missing_n = int(df[night_col].isnull().sum())
            year_stats['night_total']   += total_rows
            year_stats['night_missing'] += n_missing_n
            year_stats['night_rates'].append(n_missing_n / total_rows * 100)
            overall_night_total   += total_rows
            overall_night_missing += n_missing_n

    # ── Print Results ─────────────────────────────────────────────────────────
    print("\n" + "=" * 60)
    print("GEOGRAPHIC BOUNDS (verification)")
    if all(np.isfinite(bound) for bound in (lat_min, lat_max, lon_min, lon_max)):
        print(f"  Latitude:  {lat_min:.3f}N to {lat_max:.3f}N")
        print(f"  Longitude: {lon_min:.3f}E to {lon_max:.3f}E")
    else:
        # Avoid reporting the inf/-inf sentinels as if they were real bounds.
        print("  Not available (no finite 'Lat'/'Lon' values found in any file)")
    print("=" * 60)

    for year in sorted(stats_by_year.keys()):
        d = stats_by_year[year]
        print(f"\n[{year}]  {d['file_count']} daily files")
        print("-" * 40)

        if d['day_total'] > 0:
            pct = d['day_missing'] / d['day_total'] * 100
            rates = np.array(d['day_rates'])
            print("  Daytime SST Missingness:")
            print(f"    Overall:              {pct:.2f}%  ({d['day_missing']:,} / {d['day_total']:,} cells)")
            print(f"    Daily min/median/max: {np.nanmin(rates):.1f}% / {np.nanmedian(rates):.1f}% / {np.nanmax(rates):.1f}%")
            print(f"    Days > 80% missing:   {int(np.sum(rates > 80))}")
            print(f"    Days < 5% missing:    {int(np.sum(rates < 5))}")

        if d['spatial_variances']:
            svs = np.array(d['spatial_variances'])
            print("  Spatial Variance (Finding 3 — temporal heteroscedasticity):")
            print(f"    min / median / max:   {np.nanmin(svs):.3f} / {np.nanmedian(svs):.3f} / {np.nanmax(svs):.3f}  (C^2)")

        if d['night_total'] > 0:
            pct_n = d['night_missing'] / d['night_total'] * 100
            rates_n = np.array(d['night_rates'])
            print("  Nighttime SST Missingness:")
            print(f"    Overall:              {pct_n:.2f}%  ({d['night_missing']:,} / {d['night_total']:,} cells)")
            print(f"    Daily min/median/max: {np.nanmin(rates_n):.1f}% / {np.nanmedian(rates_n):.1f}% / {np.nanmax(rates_n):.1f}%")

    print("\n" + "=" * 60)
    print("MULTI-YEAR AGGREGATE (2022-2024)")
    print("=" * 60)

    if overall_day_total > 0:
        overall_pct = overall_day_missing / overall_day_total * 100
        print(f"  Total Daytime Missingness: {overall_pct:.2f}%  ({overall_day_missing:,} / {overall_day_total:,} cells)")

    if overall_night_total > 0:
        overall_pct_n = overall_night_missing / overall_night_total * 100
        print(f"  Total Nighttime Missingness: {overall_pct_n:.2f}%  ({overall_night_missing:,} / {overall_night_total:,} cells)")

    if spatial_variances:
        svs_all = np.array(spatial_variances)
        sv_min = np.nanmin(svs_all)
        sv_max = np.nanmax(svs_all)
        print("\n  Spatial Variance across all days:")
        print(f"    min:    {sv_min:.3f} C^2")
        print(f"    median: {np.nanmedian(svs_all):.3f} C^2")
        print(f"    max:    {sv_max:.3f} C^2")
        if sv_min > 0:
            print(f"    ratio (max/min): {sv_max / sv_min:.1f}x  [reported in paper as Finding 3]")
        else:
            # A day whose observed cells all share one value has variance 0;
            # the ratio is then undefined rather than "inf".
            print("    ratio (max/min): undefined (minimum variance is 0)")

    print("=" * 60)
    print("\nNote: NaN values in SST column represent cloud-masked pixels (MNAR preserved).")
    print("These statistics support the EDA findings in the DMLR submission.")


# Script entry point: run the audit only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    analyze_all_years()